# coding=utf-8
import os
import time
import tensorflow as tf
from utils import get_data, data_hparams, GetEditDistance, decode_ctc
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau, LambdaCallback
from sklearn.metrics import roc_auc_score
import numpy as np
import matplotlib.pyplot as plt
import warnings

# Number of training utterances consumed per acoustic-model training round.
dataLength = 1200
# Root directory for saved models and TensorBoard logs (machine-specific absolute path).
RootPath = "E:/DeepLearning/dl_studio/bishe/my_ch_speech_recognition-master/log/"

# Raise TensorFlow's C++ minimum log level to 2 and silence all Python warnings.
# NOTE(review): tensorflow is imported above this point, so log lines emitted
# during import may not be filtered — move this before the imports if needed.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "2"})
warnings.filterwarnings(action='ignore')
# Optional GPU memory cap (disabled):
# gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
# sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

def _load_split(data_type, batch_size, shuffle):
    """Build one THCHS-30-only dataset split through ``get_data``.

    Only the THCHS-30 corpus is enabled; mmcs, aishell, prime and stcmd are off.

    Args:
        data_type: split name stored on the hparams (``'train'`` or ``'dev'`` here).
        batch_size: batch size used by the split's batch generators.
        shuffle: shuffle flag handed to ``get_data``.

    Returns:
        The object returned by ``get_data`` for this configuration.
    """
    args = data_hparams()
    args.data_type = data_type
    args.data_path = '../data/'
    args.mmcs = False
    args.thchs30 = True
    args.aishell = False
    args.prime = False
    args.stcmd = False
    args.batch_size = batch_size
    # Set a number here (e.g. 10000 for train, 2000 for dev; dev max is 893)
    # to limit the split; None presumably means "use everything".
    args.data_length = None
    args.shuffle = shuffle
    return get_data(args)


# 0. Training data ----------------------------------------------------------
train_data = _load_split('train', batch_size=8, shuffle=True)
# Total training items; split into rounds of `dataLength` below.
count_length = train_data.countLength

# 0. Validation data --------------------------------------------------------
# shuffle must stay False: the evaluation loop pairs the dev generator's
# output with dev_data.pny_lst by position.
dev_data = _load_split('dev', batch_size=1, shuffle=False)
# Recomputed per acoustic round; this value is only used if no round runs.
batch_num = len(train_data.wav_lst) // train_data.batch_size
# 1. Acoustic-model training, in rounds of `dataLength` utterances each.
# Each round warm-starts from the previous round's weights (if present), trains
# on its own slice of the training list, saves weights to
# RootPath/model/am/epoch_<i>/<i>.h5, then decodes a few dev utterances and
# appends predictions plus the word error rate to the evaluation log.
from model_speech.cnn_ctc import Am, am_hparams
# from model_speech.gru_ctc import Am, am_hparams

AM_EVAL_LOG = "../LogDir/Logs/log/logout.txt"  # appended to every round
AM_EVAL_SAMPLES = 5  # dev utterances decoded per round

dirPath0 = None  # previous round's model directory (None in round 0)
modelPath0 = None  # previous round's weight file (None in round 0)
# NOTE(review): integer division drops the trailing count_length % dataLength
# items, which are never used for acoustic training — confirm this is intended.
for i in range(count_length // dataLength):
    if i >= 1:
        dirPath0 = RootPath + "model/am/epoch_" + str(i - 1)
        modelPath0 = dirPath0 + "/" + str(i - 1) + ".h5"
    dirPath = RootPath + "model/am/epoch_" + str(i)
    # makedirs also creates missing parents (model/am) instead of raising.
    os.makedirs(dirPath, exist_ok=True)

    # Restrict the training list to this round's slice and reload it.
    start = i * dataLength
    end = start + dataLength - 1  # NOTE(review): presumably inclusive — confirm in adjustDataList
    train_data.starItem = start
    train_data.endItem = end
    train_data.adjustDataList()
    print("训练迭代数据轮:", str(i + 1))

    am_args = am_hparams()
    am_args.vocab_size = len(train_data.am_vocab)
    am_args.gpu_nums = 1
    am_args.lr = 0.0008
    am_args.is_training = True
    # NOTE(review): a fresh model is built every round in the same Keras
    # session; graph/memory growth across rounds is possible — verify.
    am = Am(am_args)

    print("数据开始:", train_data.starItem)
    print("数据结束:", train_data.endItem)

    epochs = 240
    # Kept global: the language-model phase below also reads batch_num.
    batch_num = len(train_data.wav_lst) // train_data.batch_size

    if dirPath0 is not None and os.path.exists(modelPath0):
        print('load acoustic model...')
        am.ctc_model.load_weights(modelPath0)  # warm start from previous round
    modelPath = dirPath + "/" + str(i) + ".h5"  # this round's weights

    batch = train_data.get_am_batch()
    dev_batch = dev_data.get_am_batch()

    # write_grads only takes effect when histogram_freq > 0, so it is inert here.
    tensorBoard = TensorBoard(log_dir=RootPath + "tensorboard/am/" + str(int(time.time())), write_grads=True,
                              histogram_freq=0, update_freq="epoch")
    tensorBoard.set_model(am.ctc_model)
    earlyStopping = EarlyStopping(monitor='loss', min_delta=1e-5, patience=5, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=3, min_lr=0.00001)

    am.ctc_model.fit_generator(batch, steps_per_epoch=batch_num, initial_epoch=0, epochs=epochs,
                               callbacks=[tensorBoard, earlyStopping, reduce_lr],
                               workers=1, use_multiprocessing=False,
                               validation_data=dev_batch, validation_steps=10)
    am.ctc_model.save_weights(modelPath)

    # Evaluate on the first few dev utterances. A fresh generator is required
    # so that its outputs line up with dev_data.pny_lst[j].
    dev_batch = dev_data.get_am_batch()
    num_eval = min(AM_EVAL_SAMPLES, len(dev_data.pny_lst))
    word_error_num = 0
    word_num = 0
    os.makedirs(os.path.dirname(AM_EVAL_LOG), exist_ok=True)
    with open(AM_EVAL_LOG, "a") as log_file:
        log_file.write("=" * 20 + "\n")
        for j in range(num_eval):
            inputs, _ = next(dev_batch)
            result = am.model.predict(inputs['the_inputs'])
            _, result = decode_ctc(result, train_data.am_vocab)
            label = dev_data.pny_lst[j]
            log_file.write("预测：" + ','.join(result) + "\n")
            log_file.write("实际：" + ','.join(label) + "\n")
            # Edit distance between the two pinyin sequences, capped at the label length.
            word_error_num += min(len(label), GetEditDistance(label, result))
            word_num += len(label)
        # nan instead of ZeroDivisionError when no label tokens were scored.
        wer = word_error_num / word_num if word_num else float("nan")
        print('词错误率：', wer)
        strLine = '【第' + str(i) + '轮】词错误率：' + str(wer)
        log_file.write(strLine + "\n")
        log_file.write("=" * 20 + "\n")

# Announce the end of acoustic-model training, framed by separator lines
# (7 above the message, 8 below).
_BANNER_LINE = "================================="
for _ in range(7):
    print(_BANNER_LINE)
print("声学模型学习完毕")
for _ in range(8):
    print(_BANNER_LINE)

# checkpoint
# ckpt = "model_MMCS_{epoch:02d}-{val_loss:.2f}.h5"
# checkpoint = ModelCheckpoint(os.path.join('./checkpoint/Hai', ckpt), monitor='val_loss',
#                              save_weights_only=False,
#                              verbose=1,
#                              save_best_only=True)
#
# ckpt_pi = "model_{epoch:02d}-{loss:.2f}.h5"
# checkpointPi = ModelCheckpoint(os.path.join('./checkpointPi/Hai', ckpt_pi), monitor='loss',
#                                save_weights_only=False,
#                                verbose=1,
#                                save_best_only=True)


# Start training
# 2. Language-model training: a transformer mapping the pinyin vocabulary
# (inputs) to the character vocabulary (labels) ----------------------------
from model_language.transformer import Lm, lm_hparams

lm_args = lm_hparams()
lm_args.num_heads = 8
lm_args.num_blocks = 6
lm_args.input_vocab_size = len(train_data.pny_vocab)
lm_args.label_vocab_size = len(train_data.han_vocab)
lm_args.max_length = 500
lm_args.hidden_units = 512
lm_args.dropout_rate = 0.2
lm_args.lr = 0.0003
lm_args.is_training = True
lm = Lm(lm_args)

epochs = 50
with lm.graph.as_default():
    saver = tf.train.Saver()
with tf.Session(graph=lm.graph) as sess:
    # None when the graph defines no summary ops; guarded below.
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    # Epoch offset for the checkpoint name; set from a restored checkpoint
    # if the block below is re-enabled, otherwise 0.
    add_num = 0
    # if os.path.exists('logs_lm/checkpoint'):
    #     print('loading language model...')
    #     latest = tf.train.latest_checkpoint('logs_lm')
    #     add_num = int(latest.split('_')[-1])
    #     saver.restore(sess, latest)
    writer = tf.summary.FileWriter('logs_lm/tensorboard/Hai', sess.graph)
    try:
        for k in range(epochs):
            total_loss = 0
            trained_steps = 0  # batches actually trained (malformed ones are skipped)
            batch = train_data.get_lm_batch()
            # NOTE(review): batch_num is inherited from the last acoustic round
            # (its slice size // batch_size) — confirm it is the intended LM
            # steps per epoch.
            for step in range(batch_num):
                input_batch, label_batch = next(batch)
                if len(np.shape(label_batch)) < 2:
                    # Labels that are not at least 2-D are reported and skipped.
                    print(label_batch)
                    continue
                feed = {lm.x: input_batch, lm.y: label_batch}
                cost, _ = sess.run([lm.mean_loss, lm.train_op], feed_dict=feed)
                total_loss += cost
                trained_steps += 1
                print("cost=>", cost)
                global_step = k * batch_num + step
                if merged is not None and global_step % 10 == 0:
                    rs = sess.run(merged, feed_dict=feed)
                    writer.add_summary(rs, global_step)
            # Average over trained batches only; skipped ones must not dilute it.
            avg_loss = total_loss / trained_steps if trained_steps else float("nan")
            print('epochs', k + 1, ': average loss = ', avg_loss)
        save_prefix = 'log/model/lm/%d_time/modelGOBAL_%d' % (time.time(), (epochs + add_num))
        # The timestamped directory is new on every run; make sure it exists.
        os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
        saver.save(sess, save_prefix)
    finally:
        writer.close()
